package routine

import (
	"sync"
)

// AtomicRoutine 原子性协程
type AtomicRoutine interface {

	// OnAtomicPrepare 执行原子操作前的准备阶段，用于准备资源，不进行具有修改性质的执行操作
	//  @return error
	OnAtomicPrepare() error

	// Commit 执行具有修改性质的操作
	//  @return error
	Commit() error

	// Rollback 回滚到Commit()之前的状态（可以在OnAtomicPrepare()中存储原始状态）
	Rollback()

	// OnAtomicEnd 执行原子操作后的结束阶段，用于释放资源，不进行具有修改性质的执行操作
	OnAtomicEnd()
}

// AtomicRoutineBatch 原子协程池，要么所有协程都执行，要么发生错误所有协程都回滚
type AtomicRoutineBatch struct {

	// atomicRoutines 原子协程数组（首字母为小写，限制只能通过Add（）添加有序协程）
	atomicRoutines []AtomicRoutine
}

func (batch *AtomicRoutineBatch) Add(atomicRoutine AtomicRoutine) *AtomicRoutineBatch{
	batch.atomicRoutines = append(batch.atomicRoutines, atomicRoutine)
	return batch
}

// Execute 首先并发执行所有协程的OnAtomicPrepare()，发生错误直接挨个OnAtomicEnd()
// 所有协程的OnAtomicPrepare()均执行完毕且没错误产生后，并发执行所有协程的Commit()
// 所有协程的Commit()均执行完毕且没错误产生后，并发执行所有协程的OnAtomicEnd()；但凡有一个协程出错，并发执行所有协程的Rollback()和OnAtomicEnd()
//
//	@receiver batch
//	@return bool 是否有错误发生
func (batch *AtomicRoutineBatch) Execute() bool {
	errOccured1 := batch.prepare()
	if errOccured1 {
		batch.end()
		return false
	}

	errOccured2 := batch.commit()
	if errOccured2 {
		batch.rollbackAndEnd()
		return false
	}

	batch.end()
	return true
}

func (batch *AtomicRoutineBatch) prepare() bool {
	num := len(batch.atomicRoutines)
	wg := new(sync.WaitGroup)
	wg.Add(num)

	errOccured := false
	errMutex := new(sync.RWMutex)

	for i := 0; i < num; i++ {
		go func(val int) {
			defer wg.Done()

			errMutex.RLock()
			proceed := !errOccured
			errMutex.RUnlock()
			if !proceed {
				return
			}

			atomicRoutine := batch.atomicRoutines[val]
			err := atomicRoutine.OnAtomicPrepare()
			if err != nil {
				errMutex.Lock()
				errOccured = true
				errMutex.Unlock()
				return
			}

		}(i)
	}

	wg.Wait()
	return errOccured
}

func (batch *AtomicRoutineBatch) commit() bool {
	num := len(batch.atomicRoutines)
	wg := new(sync.WaitGroup)
	wg.Add(num)

	errOccured := false
	errMutex := new(sync.RWMutex)

	for i := 0; i < num; i++ {
		go func(val int) {
			defer wg.Done()

			errMutex.RLock()
			proceed := !errOccured
			errMutex.RUnlock()
			if !proceed {
				return
			}

			atomicRoutine := batch.atomicRoutines[val]
			err := atomicRoutine.Commit()
			if err != nil {
				errMutex.Lock()
				errOccured = true
				errMutex.Unlock()
				return
			}

		}(i)
	}

	wg.Wait()
	return errOccured
}

func (batch *AtomicRoutineBatch) rollbackAndEnd() {

	for i := 0; i < len(batch.atomicRoutines); i++ {
		go func(val int) {

			atomicRoutine := batch.atomicRoutines[val]
			atomicRoutine.Rollback()
			atomicRoutine.OnAtomicEnd()

		}(i)
	}

}

func (batch *AtomicRoutineBatch) end() {
	for i := 0; i < len(batch.atomicRoutines); i++ {
		go func(val int) {

			atomicRoutine := batch.atomicRoutines[val]
			atomicRoutine.OnAtomicEnd()

		}(i)
	}
}
